"""
Generate figures for the trilemma paper.
1. Z₁ divergence fan chart (EMU members to 2060)
2. Regime strain bar chart (2040 predicted CA deviation)
3. P(peg) evolution lines (forward counterfactual)
"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path

# Project layout: this script sits one directory below the project root.
PROJECT_DIR = Path(__file__).resolve().parent.parent
OUT_TABLES = PROJECT_DIR.joinpath("output", "tables")   # input tables for Figures 2-3
FIG_DIR = PROJECT_DIR.joinpath("paper", "figures")      # every figure is saved here
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Processed-data directory of the sibling "multilateral/followup" project;
# Figure 1 reads full_panel.csv from here.
MULTILATERAL_DATA = PROJECT_DIR.parent.joinpath(
    "multilateral", "followup", "data", "processed"
)

# Year each country joined the euro area, keyed by ISO3 code.
# NOTE(review): Croatia (HRV, joined 2023) is not listed -- confirm that
# excluding it is intentional.
EUROZONE_JOIN = {
    # Founding members (1999)
    'AUT': 1999, 'BEL': 1999, 'FIN': 1999, 'FRA': 1999, 'DEU': 1999,
    'IRL': 1999, 'ITA': 1999, 'LUX': 1999, 'NLD': 1999, 'PRT': 1999,
    'ESP': 1999,
    # Later entrants
    'GRC': 2001, 'SVN': 2007, 'CYP': 2008, 'MLT': 2008,
    'SVK': 2009, 'EST': 2011, 'LVA': 2014, 'LTU': 2015,
}
EUROZONE_ISO3 = set(EUROZONE_JOIN)

# Shared typography applied to every figure produced by this script.
plt.rcParams.update(
    {
        # Typeface and base text size
        "font.family": "serif",
        "font.size": 11,
        # Axes title and axis-label sizes
        "axes.titlesize": 13,
        "axes.labelsize": 11,
        # Tick-label sizes
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        # Legend text size and default figure resolution
        "legend.fontsize": 9,
        "figure.dpi": 150,
    }
)


# ── Figure 1: Z₁ Divergence Fan Chart ──────────────────────────────

def fig_z1_divergence():
    """Plot eurozone members' Z₁ trajectories, 1990-2060, with highlights.

    Reads ``full_panel.csv`` from ``MULTILATERAL_DATA`` and keeps rows for
    countries in ``EUROZONE_ISO3`` within 1990-2060. Draws:

    * every non-highlighted member as a thin gray line (one shared legend
      entry, "Other EMU members");
    * the unweighted cross-member mean of ``Z_1`` per year as a dashed black
      line;
    * the countries in ``highlight`` in colour (countries absent from the
      panel are skipped);
    * a dotted vertical line at 2024, the end of observed data, with the
      later years labelled "Projected".

    Saves ``fig1_z1_divergence.png`` to ``FIG_DIR``.
    """
    print("Figure 1: Z₁ Divergence...")

    first_year, last_year = 1990, 2060
    last_observed = 2024  # end of observed data; later years are projections

    fp = pd.read_csv(MULTILATERAL_DATA / "full_panel.csv")
    in_window = fp['iso3'].isin(EUROZONE_ISO3) & fp['year'].between(first_year, last_year)
    # Sort so each country's line is drawn chronologically even if the CSV
    # rows are not ordered by year.
    ez_fp = fp[in_window].sort_values(['iso3', 'year'])
    # Split once instead of re-filtering the whole panel for every country.
    by_country = {iso3: cdata for iso3, cdata in ez_fp.groupby('iso3')}

    highlight = {'ITA': '#d62728', 'ESP': '#ff7f0e', 'DEU': '#1f77b4',
                 'FRA': '#2ca02c', 'IRL': '#9467bd', 'GRC': '#8c564b'}

    fig, ax = plt.subplots(figsize=(10, 6))

    # Unweighted mean across eurozone members for each year.
    emu_annual = ez_fp.groupby('year')['Z_1'].mean()

    # Context lines for the non-highlighted members. Only the first carries a
    # legend label; labels starting with '_' are hidden from the legend.
    gray_label = 'Other EMU members'
    for iso3 in sorted(by_country):
        if iso3 in highlight:
            continue
        cdata = by_country[iso3]
        ax.plot(cdata['year'], cdata['Z_1'], color='#cccccc', linewidth=0.7,
                alpha=0.6, label=gray_label)
        gray_label = '_nolegend_'

    ax.plot(emu_annual.index, emu_annual.values, color='black', linewidth=2,
            linestyle='--', label='EMU mean', zorder=5)

    for iso3, color in highlight.items():
        cdata = by_country.get(iso3)
        if cdata is None:
            continue  # not in the panel: avoid an empty legend entry
        ax.plot(cdata['year'], cdata['Z_1'], color=color, linewidth=2,
                label=iso3, zorder=6)

    ax.axvline(x=last_observed, color='gray', linestyle=':', linewidth=1, alpha=0.7)
    # Blended transform: x in data units, y as a fraction of the axes height,
    # so the label sits just above the bottom edge whatever the Z₁ scale.
    ax.text(last_observed + 1, 0.02, 'Projected',
            transform=ax.get_xaxis_transform(),
            fontsize=9, color='gray', style='italic')

    ax.set_xlabel('Year')
    ax.set_ylabel('Z₁ (demographic aging factor)')
    ax.set_title('EMU Demographic Divergence: Z₁ Trajectories to 2060')
    ax.legend(loc='upper left', framealpha=0.9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    path = FIG_DIR / "fig1_z1_divergence.png"
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {path}")


# ── Figure 2: Regime Strain Bar Chart (2040) ───────────────────────

def fig_strain_bar():
    """Render the 2040 regime-strain chart as horizontal bars.

    Loads ``phase9_regime_strain.csv`` from ``OUT_TABLES``, keeps the rows for
    2040 ordered by ``predicted_ca_dev`` (smallest first), and draws one bar per
    ``iso3``: red when the predicted deviation is negative, green otherwise.
    Each bar is labelled with its signed value to one decimal. The figure is
    written to ``FIG_DIR / "fig2_regime_strain_2040.png"``.
    """
    print("Figure 2: Regime Strain Bar...")

    strain = pd.read_csv(OUT_TABLES / "phase9_regime_strain.csv")
    rows_2040 = strain[strain['year'] == 2040]
    rows_2040 = rows_2040.sort_values('predicted_ca_dev')
    deviation = rows_2040['predicted_ca_dev']

    def bar_colour(value):
        # Below-mean (negative) deviations in red, everything else green.
        return '#d62728' if value < 0 else '#2ca02c'

    fig, ax = plt.subplots(figsize=(8, 7))
    patches = ax.barh(
        rows_2040['iso3'],
        deviation,
        color=[bar_colour(value) for value in deviation],
        edgecolor='white',
        linewidth=0.5,
    )

    ax.axvline(x=0, color='black', linewidth=0.8)
    ax.set_xlabel('Predicted CA/GDP Deviation from EMU Mean (pp)')
    ax.set_title('EMU Regime Strain Index, 2040\n(Demographic Contribution to CA Imbalance)')
    ax.grid(True, axis='x', alpha=0.3)

    # Annotate each bar just past its tip: to the right of non-negative bars,
    # to the left of negative ones.
    for patch, value in zip(patches, deviation):
        if value >= 0:
            label_x, align = value + 0.2, 'left'
        else:
            label_x, align = value - 0.2, 'right'
        y_mid = patch.get_y() + patch.get_height() / 2
        ax.text(label_x, y_mid, f'{value:+.1f}', va='center', ha=align, fontsize=8)

    fig.tight_layout()
    out_file = FIG_DIR / "fig2_regime_strain_2040.png"
    fig.savefig(out_file, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {out_file}")


# ── Figure 3: P(peg) Evolution ─────────────────────────────────────

def _parse_ppeg_table(text, years=(2020, 2030, 2040, 2050)):
    """Parse the forward-counterfactual markdown table into a DataFrame.

    Each data row is expected to look like ``| ISO3 | P(y1) | P(y2) | ... |``
    with one probability column per entry of *years*, in that order; extra
    trailing columns are ignored. Lines containing ``Country`` (the header),
    separator rows (``|---|`` / ``|:--|``), and rows whose year cells are
    blank or non-numeric are skipped.

    Args:
        text: full contents of the markdown file.
        years: column labels to assign to the probability cells, in order.

    Returns:
        DataFrame indexed by ``iso3`` with one float column per year.

    Raises:
        ValueError: if no data rows could be parsed.
    """
    n_years = len(years)
    rows = []
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped.startswith('|') or 'Country' in stripped:
            continue
        # Keep empty cells so a blank entry cannot shift later values into
        # the wrong year column; only the outer pipes are removed.
        cells = [cell.strip() for cell in stripped.strip('|').split('|')]
        if len(cells) < n_years + 1:
            continue
        try:
            values = [float(cell) for cell in cells[1:n_years + 1]]
        except ValueError:
            continue  # separator row, or a blank / non-numeric cell
        rows.append({'iso3': cells[0], **dict(zip(years, values))})

    if not rows:
        raise ValueError("no P(peg) rows found in the forward-counterfactual table")
    return pd.DataFrame(rows).set_index('iso3')


def fig_ppeg_evolution():
    """Line chart of forward-counterfactual P(peg) paths for EMU members.

    Reads the markdown table ``phase9_forward_counterfactual.md`` from
    ``OUT_TABLES`` (one row per country: ISO3 code, then P(peg) for 2020,
    2030, 2040 and 2050), draws non-highlighted countries as thin gray
    lines and the countries in ``highlight`` in colour, and marks the
    P = 0.5 threshold. Saves ``fig3_ppeg_evolution.png`` to ``FIG_DIR``.

    Raises:
        ValueError: if the table contains no parsable country rows.
    """
    print("Figure 3: P(peg) Evolution...")

    years = [2020, 2030, 2040, 2050]
    cf_path = OUT_TABLES / "phase9_forward_counterfactual.md"
    # Explicit UTF-8: the platform default encoding can fail on non-ASCII
    # characters in the table.
    cf_df = _parse_ppeg_table(cf_path.read_text(encoding='utf-8'), years)

    highlight = {'ITA': '#d62728', 'ESP': '#ff7f0e', 'DEU': '#1f77b4',
                 'GRC': '#8c564b', 'FRA': '#2ca02c', 'SVK': '#e377c2'}

    fig, ax = plt.subplots(figsize=(10, 6))

    # All non-highlighted countries in gray.
    for iso3 in cf_df.index:
        if iso3 in highlight:
            continue
        vals = [cf_df.loc[iso3, yr] for yr in years]
        ax.plot(years, vals, color='#cccccc', linewidth=0.8, alpha=0.6,
                marker='o', markersize=3)

    # Highlighted countries present in the table.
    for iso3, color in highlight.items():
        if iso3 not in cf_df.index:
            continue
        vals = [cf_df.loc[iso3, yr] for yr in years]
        ax.plot(years, vals, color=color, linewidth=2, marker='o',
                markersize=5, label=iso3, zorder=6)

    ax.axhline(y=0.5, color='black', linestyle='--', linewidth=1, alpha=0.5)
    ax.text(2051, 0.51, 'P=0.5\nthreshold', fontsize=8, color='gray')

    ax.set_xlabel('Year')
    ax.set_ylabel('Predicted P(peg)')
    ax.set_title('Forward Counterfactual: Would EMU Members Choose to Peg?\n'
                 '(Logit trained on non-eurozone OECD)')
    ax.set_xlim(2018, 2052)
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc='upper right', framealpha=0.9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    path = FIG_DIR / "fig3_ppeg_evolution.png"
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {path}")


# ── Figure 4: Eurozone CA Effect (coefficient comparison) ──────────

def fig_ca_coefficient_comparison():
    """Bar chart comparing the Z₁ coefficient on CA/GDP across subsamples.

    ``data`` maps each subsample label to ``(coefficient, standard error)``.
    Bars show point estimates with ±1.96·SE whiskers. A bar is coloured red
    (negative) or green (positive) when |t| = |coef|/SE exceeds 1.96, and
    gray otherwise. Each whisker end is annotated with ``***`` / ``**`` /
    ``*`` for |t| above 2.576 / 1.96 / 1.645, or ``ns``. Saves
    ``fig4_ca_coefficient_comparison.png`` to ``FIG_DIR``.
    """
    print("Figure 4: CA Coefficient Comparison...")

    # NOTE(review): estimates are hard-coded; presumably transcribed from the
    # regression tables -- keep them in sync when those are re-estimated.
    data = {
        'Full panel': (18.76, 10.55),
        'Trilemma\nsample': (23.76, 11.62),
        'OECD\nfloaters': (-19.59, 28.67),
        'Non-OECD': (1.17, 13.97),
        'Eurozone\n(post-join)': (-214.25, 44.80),
        'EZ post-\ncrisis': (-160.09, 52.59),
    }

    labels = list(data)
    coefs = [data[label][0] for label in labels]
    ses = [data[label][1] for label in labels]
    # zip keeps each coefficient paired with its own SE by position, which
    # stays correct even if two estimates happen to be equal.
    t_stats = [abs(c) / s for c, s in zip(coefs, ses)]

    colors = [
        ('#d62728' if c < 0 else '#2ca02c') if t > 1.96 else '#aaaaaa'
        for c, t in zip(coefs, t_stats)
    ]

    fig, ax = plt.subplots(figsize=(10, 5))

    x = range(len(labels))
    ax.bar(x, coefs, color=colors, edgecolor='white', linewidth=0.5)
    ax.errorbar(x, coefs, yerr=[1.96 * s for s in ses], fmt='none',
                color='black', capsize=4, linewidth=1)

    ax.axhline(y=0, color='black', linewidth=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylabel('Z₁ Coefficient on CA/GDP')
    ax.set_title('The Regime-Contingent CA Effect:\nZ₁ Coefficient Across Subsamples')
    ax.grid(True, axis='y', alpha=0.3)
    # Extra vertical headroom so markers placed beyond the outermost
    # whiskers stay inside the axes.
    ax.margins(y=0.1)

    # Significance annotations
    for i, (c, s, t) in enumerate(zip(coefs, ses, t_stats)):
        if t > 2.576:
            sig = '***'
        elif t > 1.96:
            sig = '**'
        elif t > 1.645:
            sig = '*'
        else:
            sig = 'ns'
        # Anchor at the end of the CI whisker and offset a few points away
        # from the bar (up for positive, down for negative estimates), with
        # matching vertical alignment, so the marker never overlaps the
        # error-bar cap whatever the y scale.
        above = c >= 0
        whisker_end = c + 1.96 * s if above else c - 1.96 * s
        ax.annotate(sig, xy=(i, whisker_end), xytext=(0, 3 if above else -3),
                    textcoords='offset points', ha='center',
                    va='bottom' if above else 'top',
                    fontsize=10, fontweight='bold')

    fig.tight_layout()
    path = FIG_DIR / "fig4_ca_coefficient_comparison.png"
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {path}")


def main():
    """Build every paper figure in sequence, framed by console banners."""
    rule = "=" * 60
    print(rule)
    print("GENERATING FIGURES")
    print(rule)

    figure_builders = (
        fig_z1_divergence,
        fig_strain_bar,
        fig_ppeg_evolution,
        fig_ca_coefficient_comparison,
    )
    for build in figure_builders:
        build()

    print("\n" + rule)
    print(f"All figures saved to {FIG_DIR}")
    print(rule)


if __name__ == '__main__':
    main()
